{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# **ResNet50 V1**\n",
    "\n",
    "This assumes that our toolkits and its base requirements have been met, including access to the ImageNet dataset. Please refer to [\"Requirements\"](https://gitlab-master.nvidia.com/sagshelke/tensorrt_qat/-/tree/main/examples#requirements) in the `examples` folder."
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 1. Initial settings"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "outputs": [],
   "source": [
    "import os\n",
    "import tensorflow as tf\n",
    "from tensorflow_quantization.quantize import quantize_model\n",
    "from tensorflow_quantization.custom_qdq_cases import ResNetV1QDQCase\n",
    "from tensorflow_quantization.utils import convert_saved_model_to_onnx"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "HYPERPARAMS = {\n",
    "    \"tfrecord_data_dir\": \"/media/Data/ImageNet/train-val-tfrecord\",\n",
    "    \"batch_size\": 64,\n",
    "    \"epochs\": 2,\n",
    "    \"steps_per_epoch\": 500,\n",
    "    \"train_data_size\": None,\n",
    "    \"val_data_size\": None,\n",
    "    \"save_root_dir\": \"./weights/resnet_50v1_jupyter\"\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Load data"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from examples.data.data_loader import load_data\n",
    "train_batches, val_batches = load_data(HYPERPARAMS, model_name=\"resnet_v1\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## 2. Baseline model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Instantiate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "model = tf.keras.applications.ResNet50(\n",
    "    include_top=True,\n",
    "    weights=\"imagenet\",\n",
    "    classes=1000,\n",
    "    classifier_activation=\"softmax\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Evaluate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "781/781 [==============================] - 41s 51ms/step - loss: 1.0481 - accuracy: 0.7504\n",
      "Baseline val accuracy: 75.044%\n"
     ]
    }
   ],
   "source": [
    "def compile_model(model, lr=0.001):\n",
    "    model.compile(\n",
    "        optimizer=tf.keras.optimizers.SGD(learning_rate=lr),\n",
    "        loss=tf.keras.losses.SparseCategoricalCrossentropy(),\n",
    "        metrics=[\"accuracy\"],\n",
    "    )\n",
    "\n",
    "compile_model(model)\n",
    "_, baseline_model_accuracy = model.evaluate(val_batches)\n",
    "print(\"Baseline val accuracy: {:.3f}%\".format(baseline_model_accuracy*100))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Save and convert to ONNX"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: ./weights/resnet_50v1_jupyter/saved_model_baseline/assets\n",
      "ONNX conversion Done!\n"
     ]
    }
   ],
   "source": [
    "model_save_path = os.path.join(HYPERPARAMS[\"save_root_dir\"], \"saved_model_baseline\")\n",
    "model.save(model_save_path)\n",
    "convert_saved_model_to_onnx(saved_model_dir=model_save_path,\n",
    "                            onnx_model_path=model_save_path + \".onnx\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## 3. Quantization-Aware Training model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Quantize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "q_model = quantize_model(model, custom_qdq_cases=[ResNetV1QDQCase()])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Fine-tune"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/2\n",
      "500/500 [==============================] - 425s 838ms/step - loss: 0.4075 - accuracy: 0.8898 - val_loss: 1.0451 - val_accuracy: 0.7497\n",
      "Epoch 2/2\n",
      "500/500 [==============================] - 420s 840ms/step - loss: 0.3960 - accuracy: 0.8918 - val_loss: 1.0392 - val_accuracy: 0.7511\n"
     ]
    },
    {
     "data": {
      "text/plain": "<keras.callbacks.History at 0x7f9cec1e60d0>"
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "compile_model(q_model)\n",
    "q_model.fit(\n",
    "    train_batches,\n",
    "    validation_data=val_batches,\n",
    "    batch_size=HYPERPARAMS[\"batch_size\"],\n",
    "    steps_per_epoch=HYPERPARAMS[\"steps_per_epoch\"],\n",
    "    epochs=HYPERPARAMS[\"epochs\"]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Evaluate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "781/781 [==============================] - 179s 229ms/step - loss: 1.0392 - accuracy: 0.7511\n",
      "QAT val accuracy: 75.114%\n"
     ]
    }
   ],
   "source": [
    "_, qat_model_accuracy = q_model.evaluate(val_batches)\n",
    "print(\"QAT val accuracy: {:.3f}%\".format(qat_model_accuracy*100))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Save and convert to ONNX"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:absl:Found untraced functions such as conv1_conv_layer_call_fn, conv1_conv_layer_call_and_return_conditional_losses, conv2_block1_1_conv_layer_call_fn, conv2_block1_1_conv_layer_call_and_return_conditional_losses, conv2_block1_2_conv_layer_call_fn while saving (showing 5 of 140). These functions will not be directly callable after loading.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: ./weights/resnet_50v1_jupyter/saved_model_qat/assets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: ./weights/resnet_50v1_jupyter/saved_model_qat/assets\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ONNX conversion Done!\n"
     ]
    }
   ],
   "source": [
    "q_model_save_path = os.path.join(HYPERPARAMS[\"save_root_dir\"], \"saved_model_qat\")\n",
    "q_model.save(q_model_save_path)\n",
    "convert_saved_model_to_onnx(saved_model_dir=q_model_save_path,\n",
    "                            onnx_model_path=q_model_save_path + \".onnx\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## 4. QAT vs Baseline comparison"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Baseline vs QAT: 75.044% vs 75.114%\n",
      "Accuracy difference of +0.070%\n"
     ]
    }
   ],
   "source": [
    "print(\"Baseline vs QAT: {:.3f}% vs {:.3f}%\".format(baseline_model_accuracy*100, qat_model_accuracy*100))\n",
    "\n",
    "acc_diff = (qat_model_accuracy - baseline_model_accuracy)*100\n",
    "acc_diff_sign = \"\" if acc_diff == 0 else (\"-\" if acc_diff < 0 else \"+\")\n",
    "print(\"Accuracy difference of {}{:.3f}%\".format(acc_diff_sign, abs(acc_diff)))"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "```{note}\n",
    "\n",
    "For full workflow, including TensorRT(TM) deployment, please refer to [examples/resnet](https://github.com/NVIDIA/TensorRT/tree/main/tools/tensorflow-quantization/examples/resnet).\n",
    "\n",
    "```"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}